{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the RNG seed so the noisy targets are identical on every Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Inputs: 100 evenly spaced points in [-1, 1]; unsqueeze(dim=1) gives shape (100, 1).\n",
    "x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)\n",
    "# Targets: y = x^2 plus uniform noise drawn from [0, 0.2).\n",
    "y = x.pow(2) + 0.2 * torch.rand(x.size())\n",
    "# No Variable() wrapping: torch.autograd.Variable is deprecated and plain tensors work with autograd."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAcg0lEQVR4nO3df4wcZ33H8fc3x0VsCvUFYiC+xLUrpQaqAA5HQjEqiSl1korGpK1IQPxIkay0pCKRanGoEkWiVUytFooIRG4aAS0irkoaXBLq0hqayjTgM3F+YzDhR+6cEkNyUMgJbPPtH7ubzO3NzD67M7PzYz8vyfLt7uzus7Nz33vmO9/neczdERGR+jul7AaIiEg+FNBFRBpCAV1EpCEU0EVEGkIBXUSkIZ5R1hufccYZvm7durLeXkSklg4ePPgDd18d91hpAX3dunXMzc2V9fYiIrVkZt9NekwpFxGRhlBAFxFpCAV0EZGGUEAXEWkIBXQRkYYorcpFRGTc3Hb3Ajv3Hubo4hJrplps37KBrRunc3t9BXQRkRG47e4F3nPrfSwdPwnAwuIS77n1PoDcgrpSLiIiI7Bz7+GngnnX0vGT7Nx7OLf3qFUPvejTFRGRohxdXBro/mHUJqCP4nRFRCRv3Y5o0lJCa6Zaub1XbVIuozhdERHJU7cjupDQC29NTrB9y4bc3q82PfRRnK6IiOQpriPaNT3OVS5rplqxf+XyPF0REclTUofTgP2zm3N/v9qkXLZv2UBrcmLZfXmfroiI5Cmpw1lUR7Q2AX3rxmmuv/xcpqdaGDDVmuSZk6dw3e5DbNqxj9vuXii7iSIiy4y6I1qblAu0g/rWjdOqeBGRWujGo1GVW9cqoHelVbwooItIlXQ7oqNQm5RLlCpeRERWqmVAH/WFBhGROqhlQFfFi4jISrXMocddaLjohavZufcw1+0+pHleRGQs1TKgw/ILDap6EZEqKWsiwVqmXHppnhcRqYro/C3O0x3MUYyVaURAV9WLiFRFmR3MRgR0Vb2ISFWU2cFsREBX1YuIVEWZHcxGBPTeeV6mp1pcf/m5uiAqIiNXZgeztlUuvUY5vFZEJMmo52+JakxAFxGpirI6mI1IuYiIiAK6iEhj9A3oZnazmT1mZvcnPG5m9mEzO2Jm95rZefk3U0RE+gnpoX8cuDjl8UuAczr/tgEfy94sEREZVN+A7u53Ao+nbHIZ8ElvuwuYMrMz82qgiIiEyaPKZRp4JHJ7vnPfo70bmtk22r141q5dm8NbxytrYhwRkTLlcVHUYu7zuA3dfZe7z7j7zOrVq3N465XKnBhHRKRMefTQ54GzI7fPAo7m8LpD0XqjIlKGKmQG8uih7wHe2ql2eSXwI3dfkW4ZFc28KCKjVpXMQN8eupl9GrgQOMPM5oE/ByYB3P1G4A7gUuAI8CRwVVGNDbFmqsVCTPDunRinCn9NRaQZqpIZ6BvQ3f3KPo878M7cWpTR9i0blq1eBE9PjNMN4guLSxhPJ/q1wpGIZFGVzEDj5nLpnRhnVWsSM7h296FlQbz3qq3y7CIyrNDMQNEaOfR/68Zp9s9u5oNvfBk/O/ELnnjyOJBQehOhPLuIDKMqazI0roceFZfXSqMVjkRkGGVOmRvV6IA+SI9bKxyJSBZVWJOhkSmXrn497u6IKK1wJCJN0OgeelzFS/fC6LRKFUWkYRod0KuS1xIRGYVGB3SoRl5LRGQUGh/QRUQGVdeR5AroIiIR3XlZutfekkaSVzHoN7rKRURkUGnzsnRVZTKuXgroIiIRIfOyhAT9Miigi4hEJI1fid5flcm4eimgi4hEhMzLEhL0y6CALiISsXXjNNdffi7TUy2M+JHkVZmMq5eqXEREeiSNX4lWtqxqTfLMyVNYfPJ4ZapcFNBFRAL0ljMuLh2nNTnBB9/4stIDeZdSLiIiAapa2RKlgC4iEqCqlS1RCugi
IgGqWtkSpYAuIhKgqpUtUbooSjXnZBCRaqnDdNxjH9BDJ+IREan6dNxjn3Kpw5VrEZEQY99Dr8OVaxEpT51SsmPfQ6/DlWsRKUdVp8lNEhTQzexiMztsZkfMbDbm8VVm9q9mdo+ZPWBmV+Xf1GLU4cq1iJSjbinZvikXM5sAbgBeB8wDB8xsj7s/GNnsncCD7v56M1sNHDazT7n7zwtpdY7qcOVaRMpRt5RsSA79fOCIuz8MYGa3AJcB0YDuwLPNzIBnAY8DJ3Jua2GqfuVaRMqxZqrFQkzwrmpKNiTlMg08Erk937kv6iPAi4CjwH3Au9z9F70vZGbbzGzOzOaOHTs2ZJNFREajbinZkIBuMfd5z+0twCFgDfAy4CNm9ssrnuS+y91n3H1m9erVAzdWRGSUQuZGr5KQlMs8cHbk9lm0e+JRVwE73N2BI2b2beCFwFdzaaWISEnqlJINCegHgHPMbD2wAFwBvKlnm+8BrwX+28yeD2wAHs6zoWVLqkWtU42qiDSbtTvVfTYyuxT4EDAB3Ozuf2lmVwO4+41mtgb4OHAm7RTNDnf/x7TXnJmZ8bm5uYzNH43e6QGg/SE98n9Xa3Ki0qdkIlJvZnbQ3WdiHwsJ6EWoakCP63Hv3Hs49kp3kumpFvtnNxfYShEZV2kBfeyH/kclTdTVO7Cgn6rWqIrIck1LmSqgRySNCpsw4+QAZzJVrVEVGSf9gnUTZ1od+7lcopJ61ifdV9SiJqlyjarIuAiZg6Vuw/pDKKBHJPWsu7Wn053HewvzrWe7uv51F2mKkGBdt2H9IZRyidi+ZUNsNcvC4hI79x5WqaJITYQE67oN6w+hgB4RnahrYXFpWUlib35NAVykukKCdVwHru4pU6VcemzdOM3+2c1MT7VWzG9Q9/yayLgImYOlbsP6Q6iHnqCJ+TWRcZE2LXa/Ud/X7T5U21SqAnqCJubXRMZJXGo0qVRx7ruP85mDC7UvYVTKJUHdps0Ukf6Sql8+/ZVHGlHCqB56Aq1kJNI8aWNNBtm+qhTQU6iaRaRZklKpSaPB65ZiVcqlILfdvcCmHftYP3s7m3bsq+wq4SLjJCmVeuUFZzcixaoeegGaOEeESBOkpVJnfuU5tU+xKqAXIG3Ycd0OEJE6ybIQTRNSrAroBUi6kLKwuMT62dtr+9dfpMqaXpIYQjn0AqRdSEma+U1Esml6SWIIBfQCxF146dXUA0qkLE0vSQyhgF6A3jkikjTxgBIpS9KZ8YTF/xbWrSQxhAJ6QbqTfH17x+88NY96ryYeUCJlaXpJYggF9BHQNAIixUuaPfEvtp7buFkVk5gPsFZmnmZmZnxubq6U9y6DFsUQkTyY2UF3n4l7TGWLI9KEGlcRqTalXEREGkI99BIo/SIiRVBAHzHN8yIiRQlKuZjZxWZ22MyOmNlswjYXmtkhM3vAzP4r32Y2R9o8LyIiWfTtoZvZBHAD8DpgHjhgZnvc/cHINlPAR4GL3f17Zva8ohpcd1qrVESKEtJDPx844u4Pu/vPgVuAy3q2eRNwq7t/D8DdH8u3mc2RNJhIg4xEJKuQgD4NPBK5Pd+5L+rXgNPN7EtmdtDM3ppXA5tGg4xEpCghF0XjJkLoHY30DODlwGuBFvA/ZnaXu39j2QuZbQO2Aaxdu3bw1jaA1ioVkaKEBPR54OzI7bOAozHb/MDdfwr81MzuBF4KLAvo7r4L2AXtkaLDNrruQgYZqbRRRAYVEtAPAOeY2XpgAbiCds486rPAR8zsGcCpwAXAB/Ns6DhRaaPISurk9Nc3h+7uJ4BrgL3AQ8A/ufsDZna1mV3d2eYh4N+Ae4GvAje5+/3FNbvZVNoosly3k7OwuKRFYlIEDSxy9zuAO3ruu7Hn9k5gZ35Nq5+8ehAqbRRZTuv0htFcLjnJsweRVMLowKYd+9QrkUa57e4FNu3Yx/rZ2xOPb3Vywiig5yTPNEnaEnY61ZQmCe0IafxGGAX0nOTZg4hO1B9H+XRp
itCOkMZvhNHkXDlZM9ViISZ4D9uD6JY2rp+9fUXRP+hUU5ohtCOUNn5D1S9PU0DPyfYtG5aVGkI+PYi8/1CIVMkgx3fc+A2V+C6nlEtOktYzzHpQ6VRTmizr8a0S3+XUQ89REcvMaaoAabKsx7eqX5ZTQK8BrUcqTTbM8d3NmyfNHzKuKUkFdBGpld68ea9xTkkqoItIrcTlzbumxzwlqYAuIrWSlB83YP/s5tE2pmIU0EWkMEXUiKuUN5nKFkWkEEXNkKhS3mQK6CJSiKJqxIsa89EESrmISCGKrBFXKW889dBFpBCaIXH0FNBFpBDKdY+eUi4li1YBrGpNYgaLTx7XEH+pvd5h/d3j+7rdh9i597CO7wIooJeod8Tb4tLxpx4b91njpBm6uW7NijgaCuglShvxBvFrJmruZ6mjYdcE1fE+GAX0EoVc7Y9uo16O1NUwFS863geni6IlCrnaH91Gcz9LXQ1T8aLjfXAK6CVKWwwaVlYEpPVyQlZOFynLMBUvmut8cEq5lCipCiCpyiVpDotVrUmdmkqlDbOQheZsGZy5J00RX6yZmRmfm5sr5b3rKm4eaIPESf6np1pjP/ucVFPIxc644701OTH2w/zN7KC7z8Q9ph56jUR7OQuLS6nBHHRqKtWUdrETlvfif+/l03zx68dU5RJIPfSa2rRjX+zpaJR66FJFScfuVGuSn534hXrkfaT10IMuiprZxWZ22MyOmNlsynavMLOTZvb7wzZWwvTrfWuItVRV0rG7uHRcVS0Z9Q3oZjYB3ABcArwYuNLMXpyw3QeAvXk3UlZKuzCk6USlyga9qKnUYbiQHvr5wBF3f9jdfw7cAlwWs92fAJ8BHsuxfZIgqQzsQ298GftnNyuYS2UlHbunnzYZu72qWsKFXBSdBh6J3J4HLohuYGbTwBuAzcArkl7IzLYB2wDWrl07aFslYpgyMJEqSDp2gdiqFqUOw4UEdIu5r/dK6oeAd7v7SbO4zTtPct8F7IL2RdHQRkq8QSf517wYUpRBj620Y1fH6PBCAvo8cHbk9lnA0Z5tZoBbOsH8DOBSMzvh7rfl0krJTPNiSN66Qby3hDbLsaWViLIJyaEfAM4xs/VmdipwBbAnuoG7r3f3de6+Dvhn4I8VzKtF82JInqILQMPKU3YdW+Xo20N39xNmdg3t6pUJ4GZ3f8DMru48fmPBbZQcaF4MyUO0V96Pjq3RCxop6u53AHf03BcbyN397dmbJXnTvBiSVdxQ/DQ6tkZPsy2OCa3vKFn1W5AlSsdWORTQx8TWjdNcf/m5TE+1MDT4SAbXL4XSrW+bak3yzMlTuG73IU3lPGKanGuMDFNBoFLH8dHvu05K20G7gxBXS65qqtFSQG+gvIKwSh3HR8h3vX3Lhr7T2W7asW+otUMlH0q5NEy0nMx5+hdzmNNelTqOj5DvOiRtp2qqcqmH3jDDrq4eR7+c4yP0u+6XtlM1VbnUQ2+YpF/MhcWlgS9QDbOwr9RTXt+1qqnKpYDeMGm/gIOmX/TLOT7y+q5VTVUupVwaJu7CVVQ0Lxq9cHrRC1cnLvWlKpfmy/O71nws5dESdA0UMjy7NTmROkhES3+JVFPmJeikXrZunGb/7GamE9IvE2Z9R/ypmmU83Hb3Apt27GP97O0aBNQACugNlpQXPRl4VqZqlmbLs8RVqkEBvcGSLlAl9dx7qZql2TTOoHl0UbThki5Q9Zs1T9UszZE0cljjDJpHAX0MxVU0pFW5SH2lDenXIKDmUZWLrKAJuZpj0459sUG7O5lWv7lZpHrSqlzUQ5dlNCFXs6SlVTTOoHkU0GWZPOeCkfIlpVWcdu99+5YN7J/dPPqGSSEU0GWZrBfKlK6plrSRw6FnX/pO60MBXZbJcqFM6ZrqiaZV4r7XpLOv6Ghjo92jB32nVac6dFkmyyRNqmuupu7IYUt4vPfsKzrgCJ4O5l36TqtLPXRZJsuFsjzrmsfpNH9UnzX07CtkMWjVqleT
ArqsMOxseXnVNY9T6maUnzWpTLH37CskWKtWvZoU0CVYtCe5qjWJGSw+efypXmVowOhnHCpt0mbELOqzhp59pS0GDRpFXGUK6BKktye5uHT8qce6vcrrLz+X6y8/N3P6oOlD0nv3ZZyiPmvI2VfcH+buhdHphqe/6k4BXYL0y6t2e5X7Zzdn/mVv+pD0kBx1mZ9VA47qKyigm9nFwN8CE8BN7r6j5/E3A+/u3PwJ8Efufk+eDZVyhfQY8+pV9kvd1P2Cab/9VIWUhlYdqqe+Ad3MJoAbgNcB88ABM9vj7g9GNvs28Bp3f8LMLgF2ARcU0WApR7+8anebqGEDb1oPsQkXTNP2pVIakkVID/184Ii7PwxgZrcAlwFPBXR3/3Jk+7uAs/JspJSv31qlvb3KrIE3qYfYhAummhRLihIysGgaeCRye75zX5J3AJ+Pe8DMtpnZnJnNHTt2LLyVUrrexTKmWpOcftpk4sruRQ0ySkpXLCwu5baEWtHLsiUtPJJnMNfScuMppIceN8Asds5dM7uIdkB/ddzj7r6LdjqGmZmZcubtlaENklctqlIlLV2RR/plVCmdInPUTUhLyXBCeujzwNmR22cBR3s3MrOXADcBl7n7D/NpntRVUpVG1uqNuKkJorKeBTRh+oImfAYZTkgP/QBwjpmtBxaAK4A3RTcws7XArcBb3P0bubdSameYQUYhF1H7TTYF2c4C0s4syqyuGeS9m17HL8n6BnR3P2Fm1wB7aZct3uzuD5jZ1Z3HbwTeCzwX+KiZAZxIWlFDmict2CRVqvTeDwSnCbrpiqTVeLKcBSSldFa1Jle077rdh7h296EVlSl5B/5BUyhNr+OXZFqCTjKJG/WYVrERt310etZe01OtxAUYBn3vLJ/nmZOn8MSTxxOf131fWLkA97BtSpseAJL3TRH7RapDS9BJYQYtI4zbPq1LkZbqKGJEY9JrXrf7UOrzojnqPMoqs0wPoJGe40sBXTIZNF87aB43LtURTTfkVS3SL02S1lPuSvtsg37urNMDaKTneNICF5LJoNUsg+RxW5MTmCX3eEP1q8mOLujgPP1HI7pdv+oaaH+2vKp76jA9gFSPArpkMugKRyGBEZ4ebLOYkLceZI3TfsE6pMwvOhgIVg7O6H7mLCs+RaX9AShiIJI0g1Iuksmg+dressPeC6K9F++SUh2hPd6QHH9oeiiaxghJ0UTnjb9u9yF27j0cnMsOnR6g7hOVSb4U0CWzQfO1gwTGrItmhATrYcr80j5z97EsIzZD/lBqRKj0UkCXUvX7Y5C1YiMkWOe10lKvrBOJ9ds3TZioTPKlgC6Vl6ViIyRYF1Xm128isazvoRGh0ksBXRotNFgXUeZX9ERiSa/vkMsfDKkfjRQVKUjI4KC0kbBZX1+jQ5tJI0VFOkZZFVLURGLRz7CqNZk4LYHy6eNHAV1GogrldWVUhQw6kVi//dT7GRaXjqfW9SufPl40sEgKFzK4ZxTKnCc8ZMBRlkFQExa3Do1mWBw3CuhSuFEF0n5D/MusCglZdi5kPyW19aR7LiNUpd6UcpHCjSKQhqRTyp4nvF8lTZZBUN052ctOa0m5FNClcKMIpCGDbIpaRSkvWQdBaYZFUcpFCpfXhFVpQnq3IWmPqFHn/kP206CfQcaLeuhSuFEsuBB6FjBIL3bUQ+vLHAQlzaCALiNRdBAqIp0SmvvPMy2jYC1ZKKBLIwx6FpDXRVTNeChVooAujZF3OiWk168ZD6VKFNBlLIVeRIX0Xr9mPJQqUUCXRuqX187rImrZte0iUSpblMYZdtHnYUopR1GSKRJKAV0aZ9BFn7PUc6suXKpEKRdpnGEWfc5CpYZSFeqhS+Mk5a+V15amCwroZnaxmR02syNmNhvzuJnZhzuP32tm5+XfVJEwymvLuOqbcjGzCeAG4HXAPHDAzPa4+4ORzS4Bzun8uwD4WOd/kZEbxVQDIlUUkkM/Hzji7g8DmNktwGVANKBfBnzS
2wuU3mVmU2Z2prs/mnuLRQIory3jKCTlMg08Erk937lv0G0ws21mNmdmc8eOHRu0rSIikiIkoMetbeVDbIO773L3GXefWb16dUj7REQkUEhAnwfOjtw+Czg6xDYiIlKgkIB+ADjHzNab2anAFcCenm32AG/tVLu8EviR8uciIqPV96Kou58ws2uAvcAEcLO7P2BmV3cevxG4A7gUOAI8CVxVXJNFRCSOtQtTSnhjs2PAd4d8+hnAD3JsTl6q2i6obtvUrsGoXYNpYrt+xd1jL0KWFtCzMLM5d58pux29qtouqG7b1K7BqF2DGbd2aei/iEhDKKCLiDREXQP6rrIbkKCq7YLqtk3tGozaNZixalctc+giIrJSXXvoIiLSQwFdRKQhKhvQzewPzOwBM/uFmSWW9yTN1W5mzzGzL5jZNzv/n55Tu/q+rpltMLNDkX8/NrNrO4+9z8wWIo9dOqp2dbb7jpnd13nvuUGfX0S7zOxsM/uimT3U+c7fFXks1/2VZW7/fs8tuF1v7rTnXjP7spm9NPJY7Hc6onZdaGY/inw/7w19bsHt2h5p0/1mdtLMntN5rMj9dbOZPWZm9yc8Xuzx5e6V/Ae8CNgAfAmYSdhmAvgW8KvAqcA9wIs7j/0VMNv5eRb4QE7tGuh1O238X9qDAQDeB/xpAfsrqF3Ad4Azsn6uPNsFnAmc1/n52cA3It9jbvsr7XiJbHMp8HnaE869EvhK6HMLbtergNM7P1/SbVfadzqidl0IfG6Y5xbZrp7tXw/sK3p/dV77N4HzgPsTHi/0+KpsD93dH3L3w302e2qudnf/OdCdq53O/5/o/PwJYGtOTRv0dV8LfMvdhx0VGyrr5y1tf7n7o+7+tc7P/wc8RMz0yzlIO16i7f2kt90FTJnZmYHPLaxd7v5ld3+ic/Mu2hPgFS3LZy51f/W4Evh0Tu+dyt3vBB5P2aTQ46uyAT1Q2jzsz/fOBGGd/5+X03sO+rpXsPJguqZzunVzXqmNAdrlwL+b2UEz2zbE84tqFwBmtg7YCHwlcnde+yvL3P5Bc/4X2K6od9Du5XUlfaejatdvmNk9ZvZ5M/v1AZ9bZLsws9OAi4HPRO4uan+FKPT4ClmxqDBm9h/AC2Ie+jN3/2zIS8Tcl7kOM61dA77OqcDvAu+J3P0x4P202/l+4K+BPxxhuza5+1Ezex7wBTP7eqdXMbQc99ezaP/iXevuP+7cPfT+inuLmPtC5/Yv5Fjr854rNzS7iHZAf3Xk7ty/0wHa9TXa6cSfdK5v3EZ7KcpK7C/a6Zb97h7tNRe1v0IUenyVGtDd/bcyvkTaPOzft84yeJ1TmsfyaJeZDfK6lwBfc/fvR177qZ/N7O+Az42yXe5+tPP/Y2b2L7RP9e6k5P1lZpO0g/mn3P3WyGsPvb9iZJnb/9SA5xbZLszsJcBNwCXu/sPu/SnfaeHtivzhxd3vMLOPmtkZIc8tsl0RK86QC9xfIQo9vuqeckmbq30P8LbOz28DQnr8IQZ53RW5u05Q63oDEHs1vIh2mdkvmdmzuz8Dvx15/9L2l5kZ8PfAQ+7+Nz2P5bm/ssztH/LcwtplZmuBW4G3uPs3IvenfaejaNcLOt8fZnY+7Zjyw5DnFtmuTntWAa8hcswVvL9CFHt8FXGlN49/tH9554GfAd8H9nbuXwPcEdnuUtpVEd+inarp3v9c4D+Bb3b+f05O7Yp93Zh2nUb7wF7V8/x/AO4D7u18YWeOql20r6Df0/n3QFX2F+30gXf2yaHOv0uL2F9xxwtwNXB152cDbug8fh+RCqukYy2n/dSvXTcBT0T2z1y/73RE7bqm87730L5Y+6oq7K/O7bcDt/Q8r+j99WngUeA47fj1jlEeXxr6LyLSEHVPuYiISIcCuohIQyigi4g0hAK6iEhDKKCLiDSEArqISEMooIuINMT/A8M3OBg3OfVRAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Scatter the noisy samples; detach() gives a graph-free tensor that can be converted to NumPy.\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x.detach().numpy(), y.detach().numpy())\n",
    "ax.set(title='Noisy quadratic training data', xlabel='x', ylabel='y')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(torch.nn.Module):\n",
    "    \"\"\"Two-layer perceptron: Linear -> ReLU -> Linear.\"\"\"\n",
    "\n",
    "    def __init__(self, n_feature, n_hidden, n_output):\n",
    "        \"\"\"Create the two linear layers.\n",
    "\n",
    "        n_feature: size of each input sample.\n",
    "        n_hidden: number of units in the hidden layer.\n",
    "        n_output: size of each output sample.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.hidden = torch.nn.Linear(in_features=n_feature, out_features=n_hidden)\n",
    "        self.predict = torch.nn.Linear(in_features=n_hidden, out_features=n_output)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return predict(relu(hidden(x))).\"\"\"\n",
    "        activated = F.relu(self.hidden(x))\n",
    "        return self.predict(activated)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (hidden): Linear(in_features=1, out_features=10, bias=True)\n",
      "  (predict): Linear(in_features=10, out_features=1, bias=True)\n",
      ") Sequential(\n",
      "  (0): Linear(in_features=1, out_features=10, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=10, out_features=1, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Hand-written module version of the 1 -> 10 -> 1 network.\n",
    "net1 = Net(n_feature=1, n_hidden=10, n_output=1)\n",
    "\n",
    "# The same architecture assembled from stock layers.\n",
    "net2 = torch.nn.Sequential(\n",
    "    torch.nn.Linear(in_features=1, out_features=10),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(in_features=10, out_features=1),\n",
    ")\n",
    "\n",
    "# `net` is the model used from here on; it points at the Sequential version.\n",
    "net = net2\n",
    "print(f'{net1} {net2}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean-squared-error criterion for the regression targets.\n",
    "loss_func = torch.nn.MSELoss()\n",
    "# SGD over every parameter of `net`, learning rate 0.1.\n",
    "optimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.7230, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1017, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1476, grad_fn=<MseLossBackward>)\n",
      "tensor(0.3119, grad_fn=<MseLossBackward>)\n",
      "tensor(0.3008, grad_fn=<MseLossBackward>)\n",
      "tensor(0.0817, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1225, grad_fn=<MseLossBackward>)\n",
      "tensor(0.2529, grad_fn=<MseLossBackward>)\n",
      "tensor(0.3598, grad_fn=<MseLossBackward>)\n",
      "tensor(0.3129, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1922, grad_fn=<MseLossBackward>)\n",
      "tensor(0.0462, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1993, grad_fn=<MseLossBackward>)\n",
      "tensor(0.3226, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1563, grad_fn=<MseLossBackward>)\n",
      "tensor(0.0766, grad_fn=<MseLossBackward>)\n",
      "tensor(0.2556, grad_fn=<MseLossBackward>)\n",
      "tensor(0.3112, grad_fn=<MseLossBackward>)\n",
      "tensor(0.1261, grad_fn=<MseLossBackward>)\n",
      "tensor(0.0945, grad_fn=<MseLossBackward>)\n"
     ]
    }
   ],
   "source": [
    "for t in range(100):\n",
    "    prediction = net(x)\n",
    "    loss = loss_func(prediction, y)\n",
    "\n",
    "    # backward() adds into each parameter's .grad, so clear the gradients left over\n",
    "    # from the previous step; otherwise every update uses the sum of all past gradients.\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Report every 5th step; .item() yields a plain Python float instead of a graph-attached tensor.\n",
    "    if t % 5 == 0:\n",
    "        print(f'step {t:3d}: loss = {loss.item():.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.7 64-bit",
   "language": "python",
   "name": "python37764bita7c2719225b84763be647c75e40e67b2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
